// modified from https://github.com/NVIDIA/apex/blob/master/csrc/multi_tensor_sgd_kernel.cu
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
#include "multi_tensor_apply.cuh"
#include "compat.h"

#include <assert.h>
#include <cuda_runtime.h>

#define BLOCK_SIZE 512
#define ILP 4

/**
 * Perform fused SGD on multiple buffers
 * N : number of tensor lists (3, or 4 when an fp16 copy of the weights is kept)
 * tl[0] : gradients
 * tl[1] : weights
 * tl[2] : momentum buffers
 * tl[3] : fp16 weights (if appropriate)
 * wd : weight_decay (scalar)
 * momentum : momentum (scalar)
 * dampening : momentum dampening (scalar)
 * lr : learning rate (scalar)
 * nesterov : enable nesterov (bool)
 * first_run : needed for proper momentum initialization & handling
 * wd_after_momentum : apply weight decay _after_ momentum instead of before
 * scale : factor the incoming gradients are multiplied by before the update
 **/
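// For reference, the per-element update implemented by the functor below
// (all intermediate arithmetic is carried out in fp32):
//
//   g = grad * scale
//   if (wd != 0 && !wd_after_momentum)  g += wd * w
//   if (momentum != 0)
//       m = first_run ? g : momentum * m + (1 - dampening) * g
//       g = nesterov ? g + momentum * m : m
//   if (wd != 0 && wd_after_momentum)   g += wd * w
//   w -= lr * g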
template <int N, typename T_grad, typename T_weight>
struct SGDFunctor
{
    __device__ __forceinline__ void operator()(
        int chunk_size,
        volatile int *noop_gmem,
        TensorListMetadata<N> &tl,
        float wd,
        float momentum,
        float dampening,
        float lr,
        bool nesterov,
        bool first_run,
        bool wd_after_momentum,
        float scale)
    {
        // Early exit if we don't need to do anything
        if (*noop_gmem)
            return;

        int tensor_loc = tl.block_to_tensor[blockIdx.x];
        int chunk_idx = tl.block_to_chunk[blockIdx.x];
        int n = tl.sizes[tensor_loc];

        T_grad *grad_in = (T_grad *)tl.addresses[0][tensor_loc];
        grad_in += chunk_idx * chunk_size;

        T_weight *weight_in = (T_weight *)tl.addresses[1][tensor_loc];
        weight_in += chunk_idx * chunk_size;

        T_weight *mom_in = (T_weight *)tl.addresses[2][tensor_loc];
        mom_in += chunk_idx * chunk_size;

        at::Half *model_weights_out = nullptr;
        if (N == 4)
        {
            model_weights_out = (at::Half *)tl.addresses[3][tensor_loc];
            model_weights_out += chunk_idx * chunk_size;
        }

        n -= chunk_idx * chunk_size;

        float incoming_grads[ILP];
        float incoming_weights[ILP];
        float incoming_moms[ILP];
        // The loop bounds are uniform across the block, so every thread performs
        // the same number of iterations and all threads exit together (non-divergent).
        for (int i_start = 0;
             i_start < n && i_start < chunk_size;
             i_start += blockDim.x * ILP)
        {
#pragma unroll
            for (int ii = 0; ii < ILP; ii++)
            {
                incoming_grads[ii] = 0;
                incoming_weights[ii] = 0;
                incoming_moms[ii] = 0;
                int i = i_start + threadIdx.x + ii * blockDim.x;
                if (i < n && i < chunk_size)
                {
                    incoming_grads[ii] = static_cast<float>(grad_in[i]) * scale;
                    incoming_weights[ii] = static_cast<float>(weight_in[i]);
                    incoming_moms[ii] = static_cast<float>(mom_in[i]);
                }
            }

// Note on unrolling: from a pure memory-dependency perspective there is likely
// no point unrolling the write loop, since the stores fire off as soon as their
// corresponding loads arrive. Put another way, the STGs depend on the LDGs but
// not on each other. There is still a compute-ILP benefit from unrolling, though.
#pragma unroll
            for (int ii = 0; ii < ILP; ii++)
            {
                int i = i_start + threadIdx.x + ii * blockDim.x;
                if (i < n && i < chunk_size)
                {
                    // apply weight decay before momentum if necessary
                    if (wd != 0.f && !wd_after_momentum)
                        incoming_grads[ii] += wd * incoming_weights[ii];

                    if (momentum != 0.f)
                    {
                        if (!first_run)
                            incoming_moms[ii] = incoming_moms[ii] * momentum + (1.f - dampening) * incoming_grads[ii];
                        else // initialize momentums to current incoming grads
                            incoming_moms[ii] = incoming_grads[ii];

                        if (nesterov)
                            incoming_grads[ii] += momentum * incoming_moms[ii];
                        else
                            incoming_grads[ii] = incoming_moms[ii];
                    }

                    // Apply WD after momentum if desired
                    if (wd != 0.f && wd_after_momentum)
                        incoming_grads[ii] += wd * incoming_weights[ii];

                    // adjust the weight and write out
                    weight_in[i] += (-lr * incoming_grads[ii]);

                    // if necessary, write out an fp16 copy of the weights
                    if (N == 4)
                        model_weights_out[i] = static_cast<at::Half>(weight_in[i]);

                    // also write out the new momentum
                    if (momentum != 0.f)
                        mom_in[i] = incoming_moms[ii];
                }
            }
        }
    }
};

void multi_tensor_sgd_cuda(
    int chunk_size,
    at::Tensor noop_flag,
    std::vector<std::vector<at::Tensor>> tensor_lists,
    float wd,
    float momentum,
    float dampening,
    float lr,
    bool nesterov,
    bool first_run,
    bool wd_after_momentum,
    float scale)
{
    auto num_tensors = tensor_lists.size();
    auto grad_type = tensor_lists[0][0].scalar_type();
    auto weight_type = tensor_lists[1][0].scalar_type();

    if (num_tensors == 4)
        for (size_t i = 0; i < tensor_lists[3].size(); i++)
            TORCH_CHECK(tensor_lists[3][i].scalar_type() == at::ScalarType::Half,
                        "Additional output tensors should always be fp16.");

    TORCH_CHECK(noop_flag.device() == tensor_lists[0][0].device(), "expected noop flag to be on the same device as tensors");

    // We have 4 possibilities to handle here, in terms of
    // grad_type, param_type, momentum_type, requires_fp16_copy:
    // 1. fp16, fp16, fp16, No
    // 2. fp32, fp32, fp32, No
    // 3. fp16, fp32, fp32, Yes
    // 4. fp32, fp32, fp32, Yes // this is the materialize_master_grads=True case
    // It's easier to hardcode these possibilities than to use
    // switches etc. to handle the cross-product of cases where
    // we don't want the majority of them.
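    // For orientation, the expected layout of tensor_lists (mirroring the
    // functor documentation above; the inner lists are indexed in lockstep):
    //   tensor_lists[0] : gradients
    //   tensor_lists[1] : weights
    //   tensor_lists[2] : momentum buffers
    //   tensor_lists[3] : fp16 copies of the weights (4-list cases only)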

    // Case 1. fp16, fp16, fp16, No
    if (grad_type == at::ScalarType::Half &&
        weight_type == at::ScalarType::Half &&
        num_tensors == 3)
    {
        multi_tensor_apply<3>(
            BLOCK_SIZE,
            chunk_size,
            noop_flag,
            tensor_lists,
            SGDFunctor<3, at::Half, at::Half>(),
            wd,
            momentum,
            dampening,
            lr,
            nesterov,
            first_run,
            wd_after_momentum,
            scale);
    }
    // Disabled case: fp16 grads, fp32 weights/momentum, no fp16 copy
    // (not one of the 4 supported combinations above)
    // else if (grad_type == at::ScalarType::Half &&
    //          weight_type == at::ScalarType::Float &&
    //          num_tensors == 3) {
    //   multi_tensor_apply<3>(
    //       BLOCK_SIZE,
    //       chunk_size,
    //       noop_flag,
    //       tensor_lists,
    //       SGDFunctor<3, at::Half, float>(),
    //       wd,
    //       momentum,
    //       dampening,
    //       lr,
    //       nesterov,
    //       first_run,
    //       wd_after_momentum,
    //       scale);
    // }
    // Case 2. fp32, fp32, fp32, No
    else if (grad_type == at::ScalarType::Float &&
             weight_type == at::ScalarType::Float &&
             num_tensors == 3)
    {
        multi_tensor_apply<3>(
            BLOCK_SIZE,
            chunk_size,
            noop_flag,
            tensor_lists,
            SGDFunctor<3, float, float>(),
            wd,
            momentum,
            dampening,
            lr,
            nesterov,
            first_run,
            wd_after_momentum,
            scale);
    }
    // Case 3. fp16, fp32, fp32, Yes
    else if (grad_type == at::ScalarType::Half &&
             weight_type == at::ScalarType::Float &&
             num_tensors == 4)
    {
        multi_tensor_apply<4>(
            BLOCK_SIZE,
            chunk_size,
            noop_flag,
            tensor_lists,
            SGDFunctor<4, at::Half, float>(),
            wd,
            momentum,
            dampening,
            lr,
            nesterov,
            first_run,
            wd_after_momentum,
            scale);
    }
    // Case 4. fp32, fp32, fp32, Yes
    else if (grad_type == at::ScalarType::Float &&
             weight_type == at::ScalarType::Float &&
             num_tensors == 4)
    {
        multi_tensor_apply<4>(
            BLOCK_SIZE,
            chunk_size,
            noop_flag,
            tensor_lists,
            SGDFunctor<4, float, float>(),
            wd,
            momentum,
            dampening,
            lr,
            nesterov,
            first_run,
            wd_after_momentum,
            scale);
    }
    else
    {
        AT_ERROR("multi_tensor_sgd only supports certain combinations of gradient type, weight type, and number of tensor lists. Given: ",
                 "gradient: ", grad_type, ", weight: ", weight_type, ", num_lists: ", num_tensors);
    }

    AT_CUDA_CHECK(cudaGetLastError());
}
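
// A minimal host-side usage sketch (illustrative only, not part of this file).
// It assumes the surrounding build links this translation unit and declares
// multi_tensor_sgd_cuda; the tensor shapes and the chunk size are placeholders.
//
//   auto opts_fp32 = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA);
//   at::Tensor noop_flag =
//       at::zeros({1}, at::TensorOptions().dtype(at::kInt).device(at::kCUDA));
//   std::vector<at::Tensor> grads   = {at::randn({1024}, opts_fp32)};
//   std::vector<at::Tensor> weights = {at::randn({1024}, opts_fp32)};
//   std::vector<at::Tensor> momenta = {at::zeros({1024}, opts_fp32)};
//
//   // Case 2 above: fp32 grads, fp32 weights, no fp16 copy.
//   multi_tensor_sgd_cuda(
//       /*chunk_size=*/2048 * 32,
//       noop_flag,
//       {grads, weights, momenta},
//       /*wd=*/0.f, /*momentum=*/0.9f, /*dampening=*/0.f, /*lr=*/1e-3f,
//       /*nesterov=*/false, /*first_run=*/true, /*wd_after_momentum=*/false,
//       /*scale=*/1.f);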